/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

// Configuring a distributed TPU system is achieved by running
// the following Ops:
//
// 1 Run _DisconnectHostFromDistributedTPUSystem on the CPU of each
// host. This is needed in case the system had previously been
// configured. It returns, for each host, the number of TPU chips on
// the host.
//
// 2 Run _ConfigureDistributedTPU on TPU_SYSTEM. Takes as input the
// number of chips on each host. Validates that all hosts have the
// same number of chips, and that the chips are consistent with the
// topology set by flags. Has a single output which is a proto
// describing the requested system configuration, which is sent to all
// hosts.
//
// 3 Run _InitializeHostForDistributedTPU on the CPU of each host,
// taking as input the output from _ConfigureDistributedTPU. Has a
// single Tensor output which is a vector of int32 indicating, for
// each TPU on the host, what its global TPU system id is.
//
// 4 Run _WaitForDistributedTPU on TPU_SYSTEM, taking as input the
// outputs from all the _InitializeHostForDistributedTPU
// Ops. _WaitForDistributedTPU has an attr host_specs which is a
// vector<string> giving the partial device spec for each host. These
// partial specs are combined in the Op with the outputs from the host
// initialization Ops to construct a mapping from full TPU device
// specs to global TPU ids. Has a single Tensor output which is a
// matrix of int32 indicating, for each host (outer dimension) and for
// each TPU on the host (inner dimension) what that TPU's global id
// is. _WaitForDistributedTPU also waits for the TPU distributed
// system to initialize fully, which may take several minutes for a
// large system.
//
// 5 Run _SetGlobalTPUArray on the CPU of each host, taking as input
// the output from _WaitForDistributedTPU. This Op tells each host the
// global Id of every TPU on every host.
//
// Most user code works by placing the ConfigureDistributedTPU Op on
// the desired TPU_SYSTEM device, and a graph rewrite replaces it by
// the subgraph described above.
//
//
// A distributed TPU system can be cleanly shut down by running
// the following Ops:
//
// 1 Run _DisconnectHostFromDistributedTPUSystem on the CPU of each
// host.
//
// 2 Run _ShutdownDistributedTPU on the TPU_SYSTEM where
// _ConfigureDistributedTPU was run. The Op will return an error if no
// system is configured.
//
//
// Most user code works by placing the ShutdownDistributedTPU Op on
// the desired TPU_SYSTEM device, and a graph rewrite replaces it by
// the subgraph described above.

// Internal op (step 2 of the configuration sequence described above): runs on
// TPU_SYSTEM and validates the per-host chip counts.
REGISTER_OP("_ConfigureDistributedTPU")
    .Input("inputs: N * int32")
    .Output("output: string")
    .Attr("N: int >= 1")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      // Validate that all the inputs are scalars.
      for (int i = 0; i < c->num_inputs(); ++i) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &input));
      }
      // The output is the serialized configuration proto (a scalar string).
      c->set_output(0, c->Scalar());
      // Fully qualified for consistency with the other shape functions in
      // this file.
      return ::tensorflow::Status::OK();
    })
    .Doc(R"doc(
An op that sets up the centralized structures for a distributed TPU
system.

inputs: A scalar tensor for each host indicating how many TPU chips
there are on the host.
output: A tensor containing a TPUHostConfiguration proto serialized to
a string, containing the information necessary to initialize the chips
in a host.
)doc");

// Internal op (step 4 of the configuration sequence described above): blocks
// until the distributed TPU system is fully initialized.
REGISTER_OP("_WaitForDistributedTPU")
    .Input("inputs: N * int32")
    .Output("global_tpu_array: int32")
    .Attr("host_specs: list(string)")
    .Attr("startup_timeout_sec: int = 20")
    .Attr("N: int")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      // Validate that all the inputs have the same vector shape.
      for (int i = 0; i < c->num_inputs(); ++i) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &input));
      }
      // Output is [num_hosts, num_tpus_per_host]; neither dimension is
      // known statically.
      c->set_output(0, c->UnknownShapeOfRank(2));
      return ::tensorflow::Status::OK();
    })
    .Doc(R"doc(
An op that blocks execution until a distributed TPU system has
started up. This Op must be run on the same TPU_SYSTEM device as
_ConfigureDistributedTPU, and takes as inputs the outputs from the
_InitializeHostForDistributedTPU Ops.

inputs: For each initialized host, a vector giving the global TPU id
of each TPU on the host.
global_tpu_array: A two-dimensional array. For each host (the outer
dimension) the array lists the global ids of the TPUs on that host.
host_specs: For each initialized host, the partial device specification
indicating job, replica, and task. Combining this spec with
'/device:TPU:k' gives the full device name of the k'th TPU on the
host.
startup_timeout_sec: The number of seconds to wait for the TPU system
to stabilize.
)doc");

// Internal op (step 5 of the configuration sequence described above): tells
// each host the global id of every TPU on every host.
REGISTER_OP("_SetGlobalTPUArray")
    .Input("global_tpu_array: int32")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      // The input must be the rank-2 [host, tpu] id matrix produced by
      // _WaitForDistributedTPU.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
      return ::tensorflow::Status::OK();
    })
    .Doc(R"doc(
An op that informs a host of the global ids of all of the TPUs in the
system.

global_tpu_array: A two-dimensional array. For each host (the outer
dimension) the array lists the global ids of the TPUs on that host.
)doc");

// Internal op (step 2 of the shutdown sequence described above): no inputs or
// outputs; its only effect is the stateful shutdown side effect, so the shape
// function is UnknownShape.
REGISTER_OP("_ShutdownDistributedTPU")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
An op that shuts down a running distributed TPU system. The Op returns
an error if no system is running. This Op must be run on the same
TPU_SYSTEM device as the corresponding _ConfigureDistributedTPU was run
to start the system, and must be run only after
_DisconnectHostFromDistributedTPUSystem has completed on every host in
the system.
)doc");

// Internal op (step 3 of the configuration sequence described above): runs on
// each host's CPU and connects the host's chips to the distributed system.
REGISTER_OP("_InitializeHostForDistributedTPU")
    .Input("input: string")
    .Output("tpu_ids: int32")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      // The serialized configuration input must be a scalar string.
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      // Emits a vector of TPU ids; its length is not known statically.
      c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
      return ::tensorflow::Status::OK();
    })
    .Doc(R"doc(
An op that connects each chip on the host to a centralized UberDriver to allow
them to operate as a distributed system with chips in other hosts.

input: A string containing the address of the UberDriver to connect to.
tpu_ids: A vector containing the global TPU id of each TPU on the host.
)doc");

// Internal op (step 1 of both the configuration and shutdown sequences
// described above): run on each host's CPU; its scalar output feeds
// _ConfigureDistributedTPU when reconfiguring.
REGISTER_OP("_DisconnectHostFromDistributedTPUSystem")
    .Output("number_of_tpu_chips: int32")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
An op that disconnects the TPUs on a host from a running distributed
TPU system.

number_of_tpu_chips: A scalar tensor containing the number of TPU
chips on the host.
)doc");

// Public op placed by user code on the desired TPU_SYSTEM device; a graph
// rewrite replaces it with the internal configuration subgraph described in
// the file comment above.
REGISTER_OP("ConfigureDistributedTPU")
    .Output("global_tpu_array: int32")
    .Attr("embedding_config: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
An op that sets up the centralized structures for a distributed TPU
system.

global_tpu_array: A two-dimensional array. For each host (the outer
dimension) the array lists the global ids of the TPUs on that host.
embedding_config: Internal use.
)doc");

// Public op placed by user code on the desired TPU_SYSTEM device; a graph
// rewrite replaces it with the internal shutdown subgraph described in the
// file comment above.
REGISTER_OP("ShutdownDistributedTPU")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
An op that shuts down a running distributed TPU system. The Op returns
an error if no system is running.
)doc");

}  // namespace tensorflow
